1 张量
本文摘自 知乎 M-ing大佬.
import torch
1 张量 (Tensor)
1.1 基础
# 从列表创建
torch_a = torch.tensor([[1,2],[3,4]])
# 从NumPy数组创建
torch_b = torch.tensor(np.array([[1,2],[3,4]]))
- 形状:
a.shape(返回torch.Size([2,2]));a.shape[0]: 返回第一维(行数) 2;a.shape[1]: 返回第二维(列数) 2. - 数据类型:
a.dtype. 通常来说整数列表的张量是torch.int64, 浮点数列表是torch.float32. - 设备:
device. 默认 CPU (参加 GPU运算). 两个不同设备的张量无法直接运算. - 梯度追踪:
requires_grad: 是否启用梯度追踪 (参见 自动微分)
| 数据类型 | 描述 | 取值范围/精度 |
|---|---|---|
torch.float32 |
32 位单精度浮点数 | 1e-38~1e38, 7 位精度 |
torch.float64 |
64 位双精度浮点数 | 1e-308~1e308, 15 位精度 |
torch.float16 |
16 位半精度浮点数 | 适合 GPU 加速, 精度较低 |
torch.bfloat16 |
16 位脑浮点数 | 动态范围类似 float 32 |
torch.int8 |
8 位有符号整数 | [-128,127] |
torch.uint8 |
8 位无符号整数 | [0,255] |
torch.int16 |
16 位有符号整数 | [-32768,32767] |
torch.int32 |
32 位有符号整数 | [-2^31,2^31-1] |
torch.int64 |
64 位有符号整数 | [-2^63,2^63-1] |
torch.bool |
布尔值 | True/False |
-
创建指定数据类型:
torch.tensor([[1,2],[3,4]], dtype=torch.float16).如果传入整数指定浮点类型, 整数会被转为浮点数; 反之, 会截断小数部分.
-
改变数据类型:
.to(torch.float32).
1.2 常用创建与操作
| 函数 | 说明 | 用例 |
|---|---|---|
torch.ones() |
全 1 张量 | a = torch.ones((2,3)) |
torch.zeros() |
全 0 张量 | torch.zeros((1,8)); torch.zeros(8) 生成 (8,) 形状的一维张量[^1] |
torch.ones_like() |
创建与某个张量形状相同的全 1 张量 | torch.ones_like(a) |
torch.zeros_like() |
创建与某个张量形状相同的全 0 张量 | |
torch.empty() |
快速创建一个未初始化的张量; 值来自内存随机残留的数据 | torch.empty(3,4) |
.clone() |
深度拷贝一个张量, 与原张量完全独立 | b = a.clone(); 直接赋值 b=a 则会指向一个内存, 修改时同时变动 |
torch.arange(start, end, step) |
生成从 start 到 end 的左闭右开、步长为 step 的一维张量 | tensor(0, 10, 2): tensor([0, 2, 4, 6, 8]) |
torch.linspace(start, end, steps) |
生成从 start 到 end (包括两端), 共 steps 个等间隔点的张量 | torch.linspace(0, 10, 5): tensor([0.0000, 2.5000, 7.5000, 10.0000]) |
torch.item() |
将只包含一个元素的张量提取为标量 | torch.tensor([1]).item() |
torch.cat((a, b), dim) |
在现有维度上拼接两个张量, 不会增加新的维度; 要制定哪个维度 | a = torch.tensor([[1,2],[3,4]])b = torch.tensor([[5,6],[7,8]])torch.cat((a,b), dim=0): tensor([[1,2],[3,4],[5,6],[7,8]])...dim=1: tensor([[1,2,3,4],[5,6,7,8]]) |
torch.stack(a, b), dim) |
将张量在新维度上堆叠起来 | torch.stack((a,b),dim=0): tensor([[[1,2],[3,4]],[[5,6],[7,8]]])...dim=1: tensor([[[1,2],[5,6]],[[3,4],[7,8]]]) |
| [^1]: (8,) 形状的一维张量, 和 (1,8) 形状的二维张量维度不同. |
1.3 随机数与抽样
| 函数 | 说明 | 用例 |
|---|---|---|
torch.rand() |
生成指定尺寸的张量, 每个元素服从 [0,1] 上的均匀分布 |
torch.rand(2,3) |
torch.randn() |
服从标准正态分布 |
|
torch.randint(low, high, size) |
生成 [low,high) 之间、指定尺寸的随机整数 |
torch.randint(low=0, high=10, size=(2,4)) |
torch.normal(mean, std, size) |
指定均值、标准差的正态分布 | |
torch.randperm() |
生成指定长度的一维随机排列; 常用于索引打乱 | torch.randperm(3): 如 tensor([2,0,1]) |
torch.multinomial(inpput, num_samples, replacement) |
定义每个类别的概率为 input, 然后采样 num_samples 次; 定义是否放回. input 可以是一维或者二维的, 代码会自动归一化 |
probs = torch.tensor([0.1, 0.6, 0.3])torch.multinomial(probs, 5, replacement=True) 有放回地采样 5 次 |
torch.manual_seed()torch.cuda.manual_seed() |
手动设置随机数种子 |
1.4 布尔操作
| 函数 | 说明 | 用例 |
|---|---|---|
> == < |
逐元素比较两个张量的元素大小 | a = torch.tensor([1,2,3])b = torch.tensor([3,2,1])print(a>b): tensor([False, False, True]) |
| 张量、标量比较时, 结果会广播到张量相同的形状 | c = torch.tensor([[1,2],[3,4]])print(c>2): tensor([[False, False],[True, True]]) |
|
torch.any() |
只要有一个为 True 就返回 True; 可以指定 dim |
torch.any(c > 3, dim=1): tensor([False, True]) |
torch.all() |
所有元素都为 True 才返回 True; 可以指定 dim |
|
torch.nonzero() |
返回一个张量: 每一行是一个 True 元素的索引 | x = torch.tensor([5, 0, 3, 0, 8])torch.nonzero(x>0): tensor([[0], [2], [4]])二维情形: y = torch.tensor([[0, 5, 0], [2, 0, 9]])torch.nonzero(y>0): torch.tensor([[0, 1], [1, 0], [1, 2]]) |
torch.where(condition, x, y) |
根据 condition 条件, 如果为 True 选 x, 否则选 y |
x = torch.tensor([[1, 2, 3], [3, 4, 5]])y = torch.tensor([[-1, -2, -3], [-4, -5, -1]])torch.where(x > 3, x, y): tensor([[-1, -2, -3], [-4, 4, 5]]) |
布尔张量常用来作为掩码:
x = torch.tensor([10, 20, 30, 40, 50])
mask = x > 25 # tensor([False, False, True, True, True])
selected = x[mask] # tensor([30, 40, 50])
y = torch.tensor([[1,6], [8,3], [2,9]])
mask = y > 5 # tensor([[False, True], [True, False], [False, True]])
selected = y[mask] # tensor([6, 8, 9])
1.5 索引
取
a = tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]])
| 切片 | 说明 | 输出 (略去 tensor()) |
|---|---|---|
a |
原始结果 | 略 |
a[0] |
第 0 行 | [1, 2, 3, 4] |
a[0][1] 或 a[0,1] |
第 0 行第 1 列 | 2 |
a[0:3] |
第 0~2 行 | [[1,2,3,4], [5,6,7,8], [9,10,11,12]] |
a[0:3][1] |
第 0~2 行第 1 列 | [5,6,7,8] |
a[3:] |
第 3~最后 1 行 | [13,14,15,16] |
a[0:3:2] |
第 0、2 行 | [[1,2,3,4], [9,10,11,12]] |
a[-2:] |
倒数第 2行~最后 1 行 | [[9,10,11,12],[13,14,15,16]] |
切片索引创建的依然是视图, 也即改变会一起发生.
另外还有一种花式索引, 返回副本.
| 切片 | 说明 | 输出 |
|---|---|---|
a[[3,1]] |
输出第 3 行、第 1 行 | [[13,14,15,16], [5,6,7,8]] |
a[[3,1], 2] |
输出第 3 行、第 1 行的第 0 列 | [15, 7] |
1.6 变换
| 函数 | 说明 | 用例 |
|---|---|---|
.view() |
改变形状为指定的尺寸; 要求新形状与原始张量在内存上一致; 共享内存 | x.view(2, 8) 修改为 (2,8)x.view(16) 修改为 (16,) 的一维张量x.view(-1,2): 修改为 (8,2). -1 表示自动适配大小 |
.reshape() |
同样改变形状, 但没有内存一致性要求, 更稳健; 如果内存一致会共享内存, 否则创建副本 | |
.unsqueeze(dim) |
在指定的 dim 处插入大小为 1 的新维度 |
x = torch.tensor([1,2,3])x.unsqueeze(dim=0): (3,)->(1,3)x.unsqueeze(dim=1): (3,)->(3,1) |
.squeeze(dim) |
移除指定维度上大小为 1 的维度; 如果不指定移除所有大小为 1 的维度 | y = torch.randn(1, 2, 1, 3)y.squeeze(): (2,3)y.squeeze(dim=2): (1,2,3) |
.flatten() |
拉平为一维向量 | x = torch.randn(2,2,3)x.flatten(): (24,) |
.t() |
二维张量的转置 | |
.transpose(dim0, dim1) |
交换指定的两个维度 | |
.permute(*dims)[1] |
按指定顺序重排所有维度 | |
torch.rot90(a, k, dims) |
逆时针旋转 90 度 * k 次 (负数则改为顺时针); 用 dims 指定旋转的平面, 一般只会用到 2 维矩阵的旋转 |
a = torch.tensor([[1,2],[3,4]])torch.rot90(a, k=1, dims=[0,1]): tensor([[2,4], [1,3]]) |
1.7 基本运算
-
两个张量可以进行逐元素运算:
+ - * / **(注意乘法是逐元素相乘, 不是矩阵乘法)
torch.sqrt()逐元素开根
其他函数 (torch开头):log log2 log10 expsin cos tanpi e等广播机制: 当两个张量形状不同时, 会自动扩展维度让它们相容.
-
聚合运算:
sum mean max min prod(最后一个是全局乘积)std var medianargmax argmin. 可以指定dim.dim指定的参数维度会被聚合, 除非指定keepdim=True.matrix = torch.tensor([[1, -2, 3], [4, 5, 2]]) argmax = torch.argmax(matrix, dim=1, keepdim=True) # tensor([[2], [1]]) argmin = torch.argmin(matrix, dim=1) # tensor([1, 2]) -
线性代数相关:
linalg.inv linalg.det linalg.matrix_rank tracelinalg.eig(同时返回特征值和特征向量)torch.dot点积
torch.norm(p, dim, keepdim)范数. p=2就是欧式距离,p=1就是绝对值之和 -
其他
abs sign floor ceil round trunc
torch.clamp(min, max)将每一个值限定在min, max之间 -
为了节省内存, 很多操作都有原地版本:
.add_ .sub_ .mul_ .div_ .pow_ .sqrt_.zero_.fill_.abs_.log_.sin_.tan_.log2_...
transposepermute与原始数据共享内存, 但数据的存储顺序会变为不连续, 所以view可能会报错. ↩︎